🚧Latent Variable Models

Factor Analysis
Models that relate a set of observable variables to a set of unobserved (latent) variables.

General Principles

In some scenarios, the observed data does not directly reflect the underlying structure or factors influencing the outcome. Instead, latent variables—variables that are not directly observed but are inferred from the data—can help model this hidden structure. These latent variables capture unobserved factors that affect the relationship between predictors (X) and the outcome (Y).

We model the relationship between the predictor variables (X) and the outcome variable (Y) with a latent variable (Z) as follows:

Y = f(X, Z) + \epsilon

Where:
- Y is the observed outcome variable.
- X is the observed predictor variable(s).
- Z is the latent (unobserved) variable, which we aim to infer.
- f(X, Z) is the function that relates X and Z to Y.
- \epsilon is the error term, typically assumed to be normally distributed with mean 0 and variance \sigma^2.

The latent variable Z can represent various phenomena, such as group-level effects, time-varying trends, or individual-level factors, that are not captured by the observed predictors alone.
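As a minimal sketch of the relationship above, the following simulation treats Z as a group-level effect. The linear form of f, the coefficients alpha and beta, the group sizes, and the noise scale are all assumptions chosen for illustration, and NumPy is used in place of the BI package so that no library API is assumed.

```python
import numpy as np

rng = np.random.default_rng(0)

n_groups, n_per_group = 5, 40
group = np.repeat(np.arange(n_groups), n_per_group)  # group index of each observation

# Latent group-level effect Z: one unobserved value per group (assumed Normal(0, 1))
z_group = rng.normal(0.0, 1.0, size=n_groups)
Z = z_group[group]

# Observed predictor X and error term epsilon ~ Normal(0, sigma^2)
X = rng.normal(0.0, 1.0, size=group.size)
sigma = 0.3
eps = rng.normal(0.0, sigma, size=group.size)

def f(x, z, alpha=1.0, beta=2.0):
    """Assumed linear form of f: alpha + beta * x + z."""
    return alpha + beta * x + z

Y = f(X, Z) + eps

# After removing the X effect, the residuals still cluster by group:
# their group averages track the latent values z_group.
resid = Y - f(X, 0.0)
group_means = np.array([resid[group == g].mean() for g in range(n_groups)])
```

In this sketch the per-group residual averages approximate the latent effects, which is the kind of hidden structure the observed predictors alone do not capture.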

Considerations

In Bayesian regression with latent variables, we account for uncertainty in both the observed and latent variables. We declare prior distributions for the latent variables, in addition to the usual priors for regression coefficients and intercepts. The latent variables are often given Normal (Gaussian) priors, or a Multivariate Normal prior when correlations among the latent variables need to be modeled.
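To illustrate the difference between independent Normal priors and a Multivariate Normal prior, the sketch below draws two latent dimensions both ways. The correlation of 0.7, the sample size, and the use of NumPy are assumptions made for the example, not part of the BI package.

```python
import numpy as np

rng = np.random.default_rng(1)

n_obs, n_latent = 5000, 2

# Independent standard Normal prior draws for each latent dimension
z_indep = rng.normal(0.0, 1.0, size=(n_obs, n_latent))

# Correlated latent variables: Multivariate Normal with an assumed correlation of 0.7
corr = np.array([[1.0, 0.7],
                 [0.7, 1.0]])
L = np.linalg.cholesky(corr)   # lower-triangular factor, corr = L @ L.T
z_corr = z_indep @ L.T         # each row is a draw from MVN(0, corr)

# Empirical correlations between the two latent dimensions
emp_indep = np.corrcoef(z_indep, rowvar=False)[0, 1]
emp_corr = np.corrcoef(z_corr, rowvar=False)[0, 1]
```

Transforming independent draws with a Cholesky factor is one common way to obtain correlated Normal draws; a library's Multivariate Normal distribution would serve the same purpose.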

The goal is to infer the posterior distribution over both the parameters and the latent variables, given the observed data.
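Written out, this joint posterior follows from Bayes' rule. Here \theta is a symbol introduced for illustration that collects the regression coefficients, intercepts, and \sigma, and the factorization assumes that \theta and Z are independent a priori:

```latex
p(\theta, Z \mid X, Y) \;\propto\; p(Y \mid X, Z, \theta)\, p(Z)\, p(\theta)
```

The first factor is the likelihood implied by Y = f(X, Z) + \epsilon, and the remaining factors are the priors on the latent variables and on the parameters.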

Example

Below is an example code snippet demonstrating Bayesian regression with latent variables using the BI package (with JAX arrays):

from BI import bi
import jax.numpy as jnp


# Setup device------------------------------------------------
m = bi(platform='cpu')

# Data Simulation ------------------------------------------------
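NY = 4   # Number of outcome variables (columns of Y2)
NV = 8   # Number of observations (rows of Y2), e.g., subjects
N = 100  # Number of individuals for the random-effects illustration below
K = 5    # Number of covariates for the random-effects illustration below
a = 0.5  # Baseline (population-level) effect added to the individual-level effects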
# Generate the means and offsets for the data
# means: Generate random normal means for each of the NY outcomes
# offsets: Generate random normal offsets for each of the NV observations
means = m.dist.normal(0, 1, shape=(NY,), sample=True, seed=10)
offsets = m.dist.normal(0, 1, shape=(NV, 1), sample=True, seed=20)

Y2 = offsets + means

# Simulate individual-level random effects (e.g., random slopes or intercepts)
# b_individual: A matrix of size (N, K) where N is the number of individuals and K is the number of covariates
# Note: b_individual and mu are illustrative only; they are not passed to the model below
b_individual = m.dist.normal(0, 1, shape=(N, K), sample=True, seed=0)

# mu: Add an additional effect 'a' to the individual-level random effects 'b_individual'
# 'a' could represent a population-level effect or a baseline
mu = b_individual + a

# Convert Y2 to a JAX array for further computation in a JAX-based framework
Y2 = jnp.array(Y2)


# Set data ------------------------------------------------
dat = dict(
    NY = NY,
    NV = NV,
    Y2 = Y2
)
m.data_on_model = dat

# Define model ------------------------------------------------
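def model(NY, NV, Y2):
    # Priors: one mean per outcome, one latent offset per observation, one scale per outcome
    means = m.dist.normal(0, 1, shape=(NY,), name='means')
    offset = m.dist.normal(0, 1, shape=(NV, 1), name='offset')
    sigma = m.dist.exponential(1, shape=(NY,), name='sigma')
    # Repeat the outcome means for every observation: shape (NV, NY)
    tmp = jnp.tile(means, (NV, 1)).reshape(NV, NY)
    # Linear predictor: mu_l[i, j] = means[j] + offset[i]
    mu_l = tmp + offset
    # Likelihood: each outcome column of Y2 has its own scale sigma[j]
    m.dist.normal(mu_l, jnp.tile(sigma, [NV, 1]), obs=Y2)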

# Run sampler ------------------------------------------------
m.fit(model, progress_bar=False)

# Summary ------------------------------------------------
m.summary()
BI v 0.0.26 package loaded
jax.local_device_count 32
                mean    sd  hdi_5.5%  hdi_94.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail   r_hat
means[0]       -1.51  0.17     -1.65      -1.21       0.08     0.07      4.28     11.32  361.95
means[1]       -0.79  0.17     -0.92      -0.49       0.08     0.07      4.29     11.32  357.41
means[2]        1.18  0.17      1.05       1.48       0.08     0.07      4.28     11.32  361.30
means[3]       -0.66  0.17     -0.80      -0.37       0.08     0.07      4.28     11.32  361.95
offset[0, 0]   -0.58  0.17     -0.87      -0.44       0.08     0.07      4.28     11.32  361.95
offset[1, 0]    0.81  0.17      0.52       0.95       0.08     0.07      4.28     11.32  361.95
offset[2, 0]   -1.08  0.17     -1.37      -0.94       0.08     0.07      4.28     11.32  361.96
offset[3, 0]    0.16  0.17     -0.14       0.29       0.08     0.07      4.28     11.32  361.96
offset[4, 0]   -0.23  0.17     -0.53      -0.10       0.08     0.07      4.28     11.32  361.96
offset[5, 0]   -0.74  0.17     -1.04      -0.61       0.08     0.07      4.28     11.32  361.94
offset[6, 0]    0.23  0.17     -0.06       0.37       0.08     0.07      4.28     11.32  361.96
offset[7, 0]    0.81  0.17      0.51       0.95       0.08     0.07      4.28     11.32  361.95
sigma[0]        0.00  0.00      0.00       0.00       0.00     0.00      5.98     11.36    2.92
sigma[1]        0.00  0.00      0.00       0.00       0.00     0.00      5.08     10.70   25.33
sigma[2]        0.00  0.00      0.00       0.00       0.00     0.00      4.99     11.33   24.21
sigma[3]        0.00  0.00      0.00       0.00       0.00     0.00      5.87     12.89    1.77
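The r_hat values far above 1 and ess_bulk values near 4 in the summary indicate that the chains did not converge, so these estimates should be read with caution. One plausible contributor, sketched below, is the additive structure of the linear predictor, mu_l[i, j] = means[j] + offset[i]. The sketch uses NumPy as a stand-in for jax.numpy, and the seed and the shift constant c are arbitrary. It shows that the explicit tiling in the model is equivalent to broadcasting, and that the likelihood alone cannot distinguish a constant moved between means and offset; only the Normal(0, 1) priors constrain that split. The simulated Y2 also contains no observation noise, which is consistent with the sigma estimates collapsing to 0.

```python
import numpy as np

rng = np.random.default_rng(2)
NY, NV = 4, 8

means = rng.normal(0.0, 1.0, size=(NY,))
offset = rng.normal(0.0, 1.0, size=(NV, 1))

# Explicit tiling, as in the model, versus plain broadcasting
mu_tiled = np.tile(means, (NV, 1)).reshape(NV, NY) + offset
mu_broadcast = means + offset   # shapes (NY,) + (NV, 1) broadcast to (NV, NY)

# Shifting means by +c and offset by -c leaves the linear predictor unchanged
c = 3.0
mu_shifted = (means + c) + (offset - c)
```

If this is the cause, adding observation noise to the simulated data or constraining the offsets (for example, to sum to zero) would be natural things to try; both are suggestions rather than part of the original example.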

Mathematical Details